import sys
import cv2
import numpy as np
import stereoconfig
# import pcl
# import pcl.pcl_visualization

# =================预处理
def preprocess(img1, img2):
    # 彩色图-> 灰度图
    if(img1.ndim == 3):
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) # 通过opencv加载的图像通道顺序是BGR
    if(img2.ndim == 3):
        img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

    # 直方图均衡
    img1 = cv2.equalizeHist(img1)
    img2 = cv2.equalizeHist(img2)

    return img1, img2


# ================消除畸变
def undistortion(image, camera_matrix, dist_coeff):
    # camera_matrix : 相机内参矩阵， dist_coeff : 相机畸变矩阵
    undistortion_image = cv2.undistort(image, camera_matrix, dist_coeff)

    return undistortion_image


# ===================获取畸变校正和立体校正的映射变换矩阵、重投影矩阵
# @param：config是一个类，存储着双目标定的参数:config = stereoconfig.stereoCamera()
def getRectifyTransform(height, width, config):
    # 读取内参和外参
    left_K = config.cam_matrix_left
    right_K = config.cam_matrix_right
    left_distortion = config.distortion_l
    right_distortion = config.distortion_r
    R = config.R
    T = config.T

    # 计算校正变换
    height = int(height)
    width = int(width)
    # R1：左摄像机旋转矩阵，P1：左摄像机投影矩阵，Q：深度差异映射矩阵
    R1, R2, P1, P2, Q, roi1, roi2 = cv2.stereoRectify(left_K, left_distortion, right_K, right_distortion,
                                                         (width, height), R, T, alpha=0)
    # 校正查找映射表，将原始图像和校正后的图像上的点一一对应起来
    # 计算校正投影矩阵
    # map1:输出的X坐标重映射参数
    # map2:输出的Y坐标重映射参数
    map1x, map1y = cv2.initUndistortRectifyMap(left_K, left_distortion, R1, P1, (width, height), cv2.CV_32FC1)
    map2x, map2y = cv2.initUndistortRectifyMap(right_K, right_distortion, R2, P2, (width, height), cv2.CV_32FC1)

    return map1x, map1y, map2x, map2y, Q


# =======================畸变校正和立体校正
def rectifyImage(image1, image2, map1x, map1y, map2x, map2y):
    rectified_img1 = cv2.remap(image1, map1x, map1y, cv2.INTER_AREA)
    rectified_img2 = cv2.remap(image2, map2x, map2y, cv2.INTER_AREA)
    return rectified_img1, rectified_img2


# ========================立体校正检验---画线
def draw_line(image1, image2):
    # 建立输出图像
    height = max(image1.shape[0], image2.shape[0])
    width = image1.shape[1] + image2.shape[1]

    output = np.zeros((height, width, 3), dtype=np.uint8)
    output[0:image1.shape[0], 0:image1.shape[1]] = image1
    output[0:image2.shape[0], image1.shape[1]:] = image2

    # 绘制等间距平行线
    line_interval = 50 # 直线间隔：50
    for k in range(height//line_interval):
        cv2.line(output, (0, line_interval * (k + 1)), (2 * width, line_interval * (k + 1)), (0, 255, 0), thickness=2,
                      lineType=cv2.LINE_AA)
    return output


# =================视差计算
def stereoMatchSGBM(left_image, right_image, down_scale=False):
    # SGBM匹配参数设置
    if left_image.ndim == 2:
        img_channels =1
    else:
        img_channels = 3
    blockSize = 3
    paraml = {'minDisparity': 0, # Minimum possible disparity value
              'numDisparities': 128, # Maximum disparity minus minimum disparity
              'blockSize': blockSize, # Matched block size
              'P1': 8 * img_channels * blockSize ** 2, # The first parameter controlling the disparity smoothness
              'P2': 32 * img_channels * blockSize ** 2, # The second parameter controlling the disparity smoothness
              'disp12MaxDiff': 1,    # 左右视差检查中允许的最大差异（以整数像素为单位）
              'preFilterCap': 63,    # 预滤波图像像素的截断值
              'uniquenessRatio': 15,   #通常，5-15范围内的值就足够了
              'speckleWindowSize': 100,  # 平滑视差区域的最大尺寸，以考虑其噪声斑点和无效。将其设置为0可禁用斑点过滤。否则，将其设置在50-200的范围内
              'speckleRange': 1,         # 每个连接组件内的最大视差变化。
              'mode': cv2.STEREO_SGBM_MODE_SGBM_3WAY  #
             }

    # 构建SGBM对象
    left_matcher = cv2.StereoSGBM_create(**paraml)
    paramr = paraml
    paramr['minDisparity'] = -paraml['numDisparities']
    right_matcher = cv2.StereoSGBM_create(**paramr)

    # 计算视差图
    size = (left_image.shape[1], left_image.shape[0])
    if down_scale == False:
        disparity_left = left_matcher.compute(left_image, right_image)
        disparity_right = right_matcher.compute(right_image, left_image)

    else:
        # 实现高斯金字塔中的下采样,抛弃偶数行和偶数列
        left_image_down = cv2.pyrDown(left_image)
        right_image_down = cv2.pyrDown(right_image)
        factor = left_image.shape[1] / left_image_down.shape[1]

        disparity_left_half = left_matcher.compute(left_image_down, right_image_down)
        disparity_right_half = right_matcher.compute(right_image_down, left_image_down)
        disparity_left = cv2.resize(disparity_left_half, size, interpolation=cv2.INTER_AREA)
        disparity_right = cv2.resize(disparity_right_half, size, interpolation=cv2.INTER_AREA)
        disparity_left = factor * disparity_left
        disparity_right = factor * disparity_right

    # 真实视差（因为SGBM算法得到的视差是x16的）
    trueDisp_left = disparity_left.astype(np.float32) /16
    trueDisp_right = disparity_right.astype(np.float32) /16

    return trueDisp_left, trueDisp_right


# ======================得到视差图后，就可以计算像素深度
# ===利用opencv函数计算深度图
def getDepthMapWithQ(disparityMap: np.ndarray, Q: np.array) -> np.array:
    """
    :param disparityMap:输入视差图
    :param Q: 输入，重投影矩阵Q是一个4*4的视差图到深度图的映射矩阵（disparity-to-depth mapping matrix）
    :return:
    # z方向坐标计算偏差很大，有-13367.69mm，暂时不用函数计算
    """
    # cv2.reprojectImageTo3D将像素坐标转换为三维坐标，该函数会返回一个3通道的矩阵，分别存储X、Y、Z坐标（左摄像机坐标下）
    points_3d = cv2.reprojectImageTo3D(disparityMap, Q, handleMissingValues=False, ddepth=-1)  # 单位是毫米（mm）
    # x, y, depth = cv2.split(points_3d)
    depthMap = points_3d[:, :, 2]   # 只取Z方向的坐标
    # ...将视差值小于0或者大于65535的值置0
    reset_index = np.where(np.logical_or(depthMap > 0.0, depthMap < -65535.0))
    depthMap[reset_index] = 0

    return depthMap.astype(np.float32)


# ===根据公式计算深度图， depth = (f*b)/[d + (Cxr-Cxl)]
# 其中 f 为焦距长度（像素焦距），b为基线长度，d为视差，c_{xl}与c_{xr}为两个相机主点的列坐标
def getDepthMapWithConfig(disparityMap : np.ndarray, config : stereoconfig.stereoCamera) -> np.ndarray:
    #fb = config.cam_matrix_left[0, 0] * (-config.T[0])
    fb = config.cam_matrix_left[0, 0] * (config.T[0])
    doffs = config.doffs
    depthMap = np.divide(fb, disparityMap + doffs)
    # problem
    #reset_index = np.where(np.logical_or(depthMap > 0.0, depthMap < -65535.0))
    reset_index = np.where(np.logical_or(depthMap < 0.0, depthMap > 65535.0))
    depthMap[reset_index] = 0
    # reset_index2 = np.where(disparityMap < 0.0)
    # depthMap[reset_index2] = 0

    return depthMap.astype(np.float32)



if __name__ == '__main__':
    # ===读取MiddleBurry数据集图片
    iml = cv2.imread('img/A0.png')   # 左图
    imr = cv2.imread('img/A1.png')  # 右图
    if(iml is None) or (imr is None):
        print("Error: Images are empty, please check your image's path!")
        sys.exit(0)
    height, width = iml.shape[0:2]

    # ===读取相机内参和外参
    # 使用之前先将标定得到的内外参数填写到stereoconfig.py中的stereoCamera类中
    config = stereoconfig.stereoCamera()  # 类实例化
    config.setMiddleBurryParams_Adirondack()
    print(config.cam_matrix_left)

    # ===立体校正
    map1x, map1y, map2x, map2y, Q = getRectifyTransform(height, width, config)  # 获取用于畸变校正和立体校正的映射矩阵以及用于计算像素空间坐标的重投影矩阵
    iml_rectified, imr_rectified = rectifyImage(iml, imr, map1x, map1y, map2x, map2y)
    print("打印重投影矩阵")
    print(Q)

    # ===绘制等间距平行线，检查立体校正的效果
    line = draw_line(iml_rectified, imr_rectified)
    cv2.imwrite('./data/check_rectificationAdirondackfu.png', line)

    # ===立体匹配
    iml_, imr_ = preprocess(iml, imr) # 预处理，一般可以削弱光照不均的影响，不做也可以
    # 获取两幅图像的视差值
    disp, _ = stereoMatchSGBM(iml, imr, False) # 这里传入的是未经校正的图像，因为数据集已校正过
    mindisp = np.min(disp)
    maxdisp = np.max(disp)
    #cv2.imwrite('./data/disparityAdirondackfu.png', disp*4)
    cv2.imwrite('./data/disparity.png', disp * 4)

    # ===计算深度图
    #depthMap = getDepthMapWithQ(disp, Q)
    depthMap = getDepthMapWithConfig(disp, config)
    minDepth = np.min(depthMap)
    maxDepth = np.max(depthMap)
    print(minDepth, maxDepth)
    # 取值范围0~255
    depthMapVis = (255.0 * (depthMap - minDepth)) / (maxDepth - minDepth)
    mindMV = np.min(depthMapVis)
    maxdMV = np.max(depthMapVis)
    depthMapVis = depthMapVis.astype(np.uint8)
    # cv2.imshow("DepthMap", depthMapVis)
    cv2.imwrite('./data/DepthMapAdirondackfu.png', depthMapVis)
    #cv2.waitKey(0)